# Copyright (c) 2023 Copyright holder of the paper "Revisiting Image Classifier Training for Improved Certified Robust Defense against Adversarial Patches" submitted to TMLR for review

# All rights reserved.

import os
import torch
import torch.nn.functional as F
import torch.utils.data
from torchvision import transforms
import torchvision.datasets as torchdatasets
from torchvision.transforms import InterpolationMode
from build import Cutout

# Per-channel statistics read by `normalize` (index 0, 1, 2 = channel 0, 1, 2).
# Alternative values kept for reference (presumably the usual ImageNet
# statistics -- confirm before re-enabling):
# mean = [0.485, 0.456, 0.406]
# std = [0.229, 0.224, 0.225]
# Active values: every channel has 0.5 subtracted and is then divided by 0.5.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]


def normalize(tensor):
    """Standardize the first three channels of a 4-D batch tensor.

    Channel ``c`` (taken along dimension 1) has the module-level ``mean[c]``
    subtracted and is divided by ``std[c]``. Only channels 0, 1 and 2 are
    read, so the result always carries exactly three channels.

    Args:
        tensor: 4-D tensor whose dimension 1 indexes the channels.

    Returns:
        A new tensor with the three normalized channels concatenated along
        dimension 1.
    """
    # Indexing with ``[idx]`` (a list) keeps the channel dimension so the
    # slices can be concatenated back together.
    normalized_channels = [
        (tensor[:, [idx], :, :] - mean[idx]) / std[idx] for idx in range(3)
    ]
    return torch.cat(normalized_channels, dim=1)


def get_num_classes(dataset):
    """Return the number of target classes for a dataset name.

    Args:
        dataset: Dataset identifier.

    Returns:
        100 for ``'cifar100'``; 10 for ``'cifar10'``, ``'svhn'`` and
        ``'imagenette'``; 1000 for every other value.
    """
    if dataset == 'cifar100':
        return 100
    if dataset in ['cifar10', 'svhn', 'imagenette']:
        return 10
    # Anything not listed above falls back to the 1000-class default.
    return 1000


def get_dataloaders(dataset, train_batch_size=128, val_batch_size=256, cutout_size=None, num_workers=6):
    """Build the training and validation data loaders for a named dataset.

    For ``'imagenet'`` / ``'imagenette'`` the images are read with
    ``ImageFolder``; training uses a random resized crop to 224 plus a random
    horizontal flip, validation resizes to 256 and center-crops to 224.

    For ``'cifar10'`` / ``'cifar100'`` / ``'svhn'`` training uses a padded
    random 32-pixel crop, a bicubic resize to 224 and a random horizontal
    flip; validation only resizes to 224.

    Every pipeline ends with ``ToTensor``; when ``cutout_size`` is given, a
    two-hole ``Cutout`` is appended to the training pipeline after it. No
    mean/std normalization is part of these pipelines.

    Args:
        dataset: One of ``'imagenet'``, ``'imagenette'``, ``'cifar10'``,
            ``'cifar100'`` or ``'svhn'``.
        train_batch_size: Batch size of the (shuffled) training loader.
        val_batch_size: Batch size of the (unshuffled) validation loader.
        cutout_size: Side length handed to ``Cutout``; ``None`` disables it.
        num_workers: Worker count used by both loaders.

    Returns:
        Tuple ``(train_loader, val_loader)``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset in ['imagenet', "imagenette"]:
        locate_dirs = get_imagenet_datapath if dataset == "imagenet" else get_imagenette_datapath
        traindir, valdir = locate_dirs()

        train_ops = [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        if cutout_size is not None:
            # Appended after ToTensor, so Cutout receives a tensor.
            train_ops.append(Cutout(n_holes=2, length=cutout_size))
        eval_pipeline = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ])

        train_dataset = torchdatasets.ImageFolder(traindir, transforms.Compose(train_ops))
        val_dataset = torchdatasets.ImageFolder(valdir, eval_pipeline)

    elif dataset in ["cifar10", "cifar100", "svhn"]:
        train_ops = [
            transforms.RandomCrop(32, padding=4),
            transforms.Resize(224, interpolation=InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
        if cutout_size is not None:
            # Appended after ToTensor, so Cutout receives a tensor.
            train_ops.append(Cutout(n_holes=2, length=cutout_size))
        train_pipeline = transforms.Compose(train_ops)
        eval_pipeline = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
        ])

        if dataset == "cifar10":
            data_root = get_cifar10_datapath()
            train_dataset = torchdatasets.CIFAR10(data_root, train=True, transform=train_pipeline)
            val_dataset = torchdatasets.CIFAR10(data_root, train=False, transform=eval_pipeline)
        elif dataset == "cifar100":
            data_root = get_cifar100_datapath()
            train_dataset = torchdatasets.CIFAR100(data_root, train=True, transform=train_pipeline)
            val_dataset = torchdatasets.CIFAR100(data_root, train=False, transform=eval_pipeline)
        else:
            # SVHN selects its partition through ``split`` instead of ``train``.
            data_root = get_svhn_datapath()
            train_dataset = torchdatasets.SVHN(data_root, split='train', transform=train_pipeline)
            val_dataset = torchdatasets.SVHN(data_root, split='test', transform=eval_pipeline)
    else:
        raise ValueError("An Invalid dataset is provided!")

    # Only the training loader shuffles its samples.
    train_loader = to_dataloader(
        train_dataset, batchsize=train_batch_size, num_workers=num_workers, shuffle=True
    )
    val_loader = to_dataloader(
        val_dataset, batchsize=val_batch_size, num_workers=num_workers, shuffle=False
    )
    return train_loader, val_loader


def get_imagenet_datapath():
    """Return the hard-coded ImageNet directories as ``(traindir, valdir)``."""
    return (
        "/home/datasets/ImageNet/ILSVRC2012/images/train/",
        "/home/datasets/ImageNet/ILSVRC2012/images/val/",
    )


def get_cifar10_datapath():
    """Return the hard-coded root directory of the CIFAR-10 dataset."""
    cifar10_root = "/home/datasets/cifar10"
    return cifar10_root


def get_cifar100_datapath():
    """Return the hard-coded root directory of the CIFAR-100 dataset."""
    cifar100_root = "/home/datasets/cifar100"
    return cifar100_root


def get_imagenette_datapath():
    """Return the hard-coded Imagenette directories as ``(traindir, valdir)``."""
    return (
        "/home/datasets/imagenette2/train",
        "/home/datasets/imagenette2/val",
    )


def get_svhn_datapath():
    """Return the hard-coded root directory of the SVHN dataset."""
    svhn_root = "/home/datasets/svhn"
    return svhn_root


def set_checkpoint_path(checkpoint_name, dataset, arch, *, checkpoint_prefix="/home/model_weights/"):
    """Create the checkpoint directory and return the checkpoint file path.

    The directory ``<checkpoint_prefix>/<dataset>/<arch>`` is created when it
    does not exist yet.

    Args:
        checkpoint_name: File stem of the checkpoint (without extension).
        dataset: Dataset name, used as the first sub-directory.
        arch: Architecture name, used as the second sub-directory.
        checkpoint_prefix: Root directory under which checkpoints are stored.
            Keyword-only; defaults to the original hard-coded location.

    Returns:
        ``<checkpoint_prefix>/<dataset>/<arch>/<checkpoint_name>.pth.tar``.

    Raises:
        OSError: If the directory cannot be created (for example when the
            target path exists but is not a directory).
    """
    checkpoint_dir = os.path.join(checkpoint_prefix, dataset, arch)
    # ``exist_ok=True`` replaces the check-then-create sequence, which could
    # fail when several processes try to create the directory at once.
    os.makedirs(checkpoint_dir, exist_ok=True)
    return os.path.join(checkpoint_dir, checkpoint_name + ".pth.tar")


def set_result_path(checkpoint_name, dataset, arch, *, result_path_prefix="/home/results/"):
    """Create and return the directory that holds the results of one run.

    Args:
        checkpoint_name: Name of the checkpoint; used as the leaf directory.
        dataset: Dataset name, used as the first sub-directory.
        arch: Architecture name, used as the second sub-directory.
        result_path_prefix: Root directory under which results are stored.
            Keyword-only; defaults to the original hard-coded location.

    Returns:
        ``<result_path_prefix>/<dataset>/<arch>/<checkpoint_name>``, which
        exists as a directory when the function returns.

    Raises:
        OSError: If the directory cannot be created (for example when the
            target path exists but is not a directory).
    """
    result_dir = os.path.join(result_path_prefix, dataset, arch, checkpoint_name)
    # ``exist_ok=True`` replaces the check-then-create sequence, which could
    # fail when several processes try to create the directory at once.
    os.makedirs(result_dir, exist_ok=True)
    return result_dir


def to_dataloader(data_set, batchsize=100, num_workers=6, shuffle=False):
    """Wrap a dataset in a ``torch.utils.data.DataLoader``.

    Args:
        data_set: Dataset to iterate over.
        batchsize: Number of samples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Whether the loader shuffles the samples.

    Returns:
        A ``DataLoader`` over ``data_set`` with ``pin_memory`` enabled.
    """
    loader_options = {
        "batch_size": batchsize,
        "num_workers": num_workers,
        "shuffle": shuffle,
        "pin_memory": True,
    }
    return torch.utils.data.DataLoader(data_set, **loader_options)


def get_classification_loss(output, labels):
    """Compute the cross-entropy loss between ``output`` and ``labels``.

    When ``output`` has more rows than there are labels, the labels are tiled
    ``len(output) // len(labels)`` times before the loss is evaluated.

    Args:
        output: Model predictions passed to ``F.cross_entropy`` as input.
        labels: Target tensor, possibly shorter than ``output``.

    Returns:
        The value returned by ``F.cross_entropy``.
    """
    num_outputs, num_labels = len(output), len(labels)
    if num_outputs > num_labels:
        targets = labels.repeat(num_outputs // num_labels)
    else:
        targets = labels
    return F.cross_entropy(output, targets)


def save_state_dict(epoch, optimizer, classifier, checkpoint_path):
    """Serialize a training checkpoint with ``torch.save``.

    The stored dictionary has the keys ``"epoch"``, ``"optimizer"`` and
    ``"state_dict"``. ``optimizer`` is stored exactly as passed in (this
    function does not call ``state_dict()`` on it), while the classifier is
    stored through ``classifier.state_dict()``.

    Args:
        epoch: Value stored under ``"epoch"``.
        optimizer: Value stored under ``"optimizer"``.
        classifier: Object whose ``state_dict()`` is stored.
        checkpoint_path: Destination handed to ``torch.save``.
    """
    checkpoint = dict(
        epoch=epoch,
        optimizer=optimizer,
        state_dict=classifier.state_dict(),
    )
    torch.save(checkpoint, checkpoint_path)


class AverageMeter(object):
    """Track the most recent value and the running weighted mean of a metric.

    Attributes set by ``reset``/``update``: ``val`` (last value), ``sum``
    (weighted sum of all values), ``count`` (total weight) and ``avg``
    (``sum / count``).
    """

    def __init__(self, name, fmt=':f'):
        """Create a meter.

        Args:
            name: Label printed in front of the values.
            fmt: Format spec (including the leading ``:``) applied to both
                the current value and the average in ``__str__``.
        """
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new value.

        Args:
            val: The latest value of the metric.
            n: Weight of ``val`` (for example the number of samples it was
                averaged over).
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A total weight of zero (e.g. ``n=0`` on a fresh meter) used to raise
        # ZeroDivisionError; keep the average at its reset value instead.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render the meter as ``"<name> <val> (<avg>)"`` using ``fmt``."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**vars(self))


class ProgressMeter(object):
    """Print a tab-separated progress line made of a batch counter and meters."""

    def __init__(self, num_batches, meters, prefix=""):
        """Store the meters and pre-compute the batch counter template.

        Args:
            num_batches: Total number of batches, shown after the slash.
            meters: Iterable of objects rendered with ``str``.
            prefix: Text placed directly in front of the batch counter.
        """
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the counter for ``batch`` and every meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        columns = [header]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` for the batch counter.

        The current batch index is padded to the width of ``num_batches`` so
        that successive lines stay aligned.
        """
        width = len(str(num_batches // 1))
        slot = '{:' + str(width) + 'd}'
        total = slot.format(num_batches)
        return '[{}/{}]'.format(slot, total)


def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each requested ``k``.

    Args:
        output: 2-D tensor of scores, one row per sample and one column per
            class.
        target: 1-D tensor holding the true class index of every sample.
        topk: Iterable of ``k`` values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk``, each
        holding the percentage of samples whose target is among the ``k``
        highest-scoring classes. A ``k`` larger than the number of classes is
        treated as "all classes".
    """
    with torch.no_grad():
        # Clamp to the class count: ``topk`` raises when asked for more
        # entries than ``output`` has columns (e.g. top-5 with 2 classes).
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)

        # pred: (maxk, batch) after the transpose, best prediction in row 0.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # The first k rows hold the hits for the k best predictions.
            correct_k = correct[:min(k, maxk)].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


